{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch_geometric.datasets import AMiner\n",
    "from torch_geometric.nn import MetaPath2Vec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MetaPath2Vec\n",
    "\n",
    "[paper](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf)  \n",
    "[code](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/metapath2vec.py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the dataset\n",
    "path = osp.join('..', 'data', 'AMiner')\n",
    "dataset = AMiner(path)\n",
    "data = dataset[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data(\n",
      "  edge_index_dict={\n",
      "    ('paper', 'written by', 'author')=[2, 9323605],\n",
      "    ('author', 'wrote', 'paper')=[2, 9323605],\n",
      "    ('paper', 'published in', 'venue')=[2, 3194405],\n",
      "    ('venue', 'published', 'paper')=[2, 3194405]\n",
      "  },\n",
      "  num_nodes_dict={\n",
      "    paper=3194405,\n",
      "    author=1693531,\n",
      "    venue=3883\n",
      "  },\n",
      "  y_dict={\n",
      "    author=[246678],\n",
      "    venue=[134]\n",
      "  },\n",
      "  y_index_dict={\n",
      "    author=[246678],\n",
      "    venue=[134]\n",
      "  }\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'dict'>\n",
      "tensor([[      0,       1,       2,  ..., 3194404, 3194404, 3194404],\n",
      "        [      0,       1,       2,  ...,    4393,   21681,  317436]])\n"
     ]
    }
   ],
   "source": [
    "print(type(data.edge_index_dict))\n",
    "print(data.edge_index_dict[('paper', 'written by', 'author')])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'dict'>\n",
      "{'paper': 3194405, 'author': 1693531, 'venue': 3883}\n"
     ]
    }
   ],
   "source": [
    "print(type(data.num_nodes_dict))\n",
    "print(data.num_nodes_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'dict'>\n",
      "tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,\n",
      "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
      "        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n",
      "        3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5,\n",
      "        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7,\n",
      "        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7])\n"
     ]
    }
   ],
   "source": [
    "print(type(data.y_dict))\n",
    "print(data.y_dict[\"venue\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'dict'>\n",
      "tensor([1741, 2245,  111,  837, 2588, 2116, 2696, 3648, 3784,  313, 3414,  598,\n",
      "        2995, 2716, 1423,  783, 1902, 3132, 1753, 2748, 2660, 3182,  775, 3339,\n",
      "        1601, 3589,  156, 1145,  692, 3048,  925, 1587,  820, 1374, 3719,  819,\n",
      "         492, 3830, 2777, 3001, 3693,  517, 1808, 2353, 3499, 1763, 2372, 1030,\n",
      "         721, 2680, 3355, 1217, 3400, 1271, 1970, 1127,  407,  353, 1471, 1095,\n",
      "         477, 3701,   65, 1009, 1899, 1442, 2073, 3143, 2466,  289, 1996, 1070,\n",
      "        3871, 3695,  281, 3633,   50, 2642, 1925, 1285, 2587, 3814, 3582, 1873,\n",
      "        1339, 3450,  271, 2966,  453, 2638, 1354, 3211,  391, 1588, 3875, 2216,\n",
      "        2146, 3765, 2486,  661, 3367,  426,  750, 2158,  519,  230, 1677,  839,\n",
      "        2945, 1313, 1037, 2879, 2225, 3523, 1247,  448,  227, 3385,  529, 2849,\n",
      "        1584, 1229,  373, 2235, 1819, 1764, 3155, 2852, 2789, 3474, 1571, 2088,\n",
      "         208,  462])\n"
     ]
    }
   ],
   "source": [
    "print(type(data.y_index_dict))\n",
    "print(data.y_index_dict[\"venue\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
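    "# select the compute device; the second assignment below forces CPU regardless\n",
    "# (presumably because the GPU used here is too small for the AMiner embedding table)\n",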
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "device = \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the MetaPath2Vec model over the author -> paper -> venue -> paper -> author metapath\n",
    "\n",
    "metapath = [\n",
    "    ('author', 'wrote', 'paper'),\n",
    "    ('paper', 'published in', 'venue'),\n",
    "    ('venue', 'published', 'paper'),\n",
    "    ('paper', 'written by', 'author'),\n",
    "]\n",
    "\n",
    "\n",
    "model = MetaPath2Vec(data.edge_index_dict, \n",
    "                     embedding_dim=128,\n",
    "                     metapath=metapath,\n",
    "                     walk_length=5, \n",
    "                     context_size=3,\n",
    "                     walks_per_node=3,\n",
    "                     num_negative_samples=1,\n",
    "                     sparse=True\n",
    "                    ).to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build a DataLoader that yields (pos_rw, neg_rw) batches of metapath random walks\n",
    "loader = model.loader(batch_size=128, shuffle=True, num_workers=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 torch.Size([1536, 3]) torch.Size([1536, 3])\n",
      "1 torch.Size([1536, 3]) torch.Size([1536, 3])\n",
      "2 torch.Size([1536, 3]) torch.Size([1536, 3])\n",
      "3 torch.Size([1536, 3]) torch.Size([1536, 3])\n",
      "4 torch.Size([1536, 3]) torch.Size([1536, 3])\n",
      "5 torch.Size([1536, 3]) torch.Size([1536, 3])\n",
      "6 torch.Size([1536, 3]) torch.Size([1536, 3])\n",
      "7 torch.Size([1536, 3]) torch.Size([1536, 3])\n",
      "8 torch.Size([1536, 3]) torch.Size([1536, 3])\n",
      "9 torch.Size([1536, 3]) torch.Size([1536, 3])\n"
     ]
    }
   ],
   "source": [
    "for idx, (pos_rw, neg_rw) in enumerate(loader):\n",
    "    if idx == 10: break\n",
    "    print(idx, pos_rw.shape, neg_rw.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1077290, 3179853, 4891621]) tensor([1077290, 3409655, 4890161])\n"
     ]
    }
   ],
   "source": [
    "print(pos_rw[0],neg_rw[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the optimizer; SparseAdam matches the sparse embedding (sparse=True)\n",
    "optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch, log_steps=500, eval_steps=1000):\n",
    "    model.train()\n",
    "\n",
    "    total_loss = 0\n",
    "    for i, (pos_rw, neg_rw) in enumerate(loader):\n",
    "        optimizer.zero_grad()\n",
    "        loss = model.loss(pos_rw.to(device), neg_rw.to(device))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        if (i + 1) % log_steps == 0:\n",
    "            print((f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '\n",
    "                   f'Loss: {total_loss / log_steps:.4f}'))\n",
    "            total_loss = 0\n",
    "\n",
    "        if (i + 1) % eval_steps == 0:\n",
    "            acc = test()\n",
    "            print((f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '\n",
    "                   f'Acc: {acc:.4f}'))\n",
    "\n",
    "@torch.no_grad()\n",
    "def test(train_ratio=0.1):\n",
    "    model.eval()\n",
    "\n",
    "    z = model('author', batch=data.y_index_dict['author'])\n",
    "    y = data.y_dict['author']\n",
    "\n",
    "    perm = torch.randperm(z.size(0))\n",
    "    train_perm = perm[:int(z.size(0) * train_ratio)]\n",
    "    test_perm = perm[int(z.size(0) * train_ratio):]\n",
    "\n",
    "    return model.test(z[train_perm], y[train_perm], z[test_perm],\n",
    "                      y[test_perm], max_iter=150)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(1, 2):\n",
    "    train(epoch)\n",
    "    acc = test()\n",
    "    print(f'Epoch: {epoch}, Accuracy: {acc:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the saved model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = MetaPath2Vec(data.edge_index_dict, \n",
    "                     embedding_dim=128,\n",
    "                     metapath=metapath,\n",
    "                     walk_length=5, \n",
    "                     context_size=3,\n",
    "                     walks_per_node=3,\n",
    "                     num_negative_samples=1,\n",
    "                     sparse=True\n",
    "                    ).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 1.1124, -0.2817,  0.5447, -0.4989, -0.6020], grad_fn=<SliceBackward>)\n"
     ]
    }
   ],
   "source": [
    "print(loaded_model.embedding.weight[1][:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "CUDA out of memory. Tried to allocate 2.33 GiB (GPU 0; 1.96 GiB total capacity; 0 bytes already allocated; 1.12 GiB free; 0 bytes reserved in total by PyTorch)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-35-875aa950260f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# load the model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mloaded_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"mymodel\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m    591\u001b[0m                     \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    592\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0m_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_zipfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 593\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0m_legacy_load\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    594\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    595\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_legacy_load\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m    771\u001b[0m     \u001b[0munpickler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnpickler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    772\u001b[0m     \u001b[0munpickler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpersistent_load\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpersistent_load\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 773\u001b[0;31m     \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0munpickler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    774\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    775\u001b[0m     \u001b[0mdeserialized_storage_keys\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mpersistent_load\u001b[0;34m(saved_id)\u001b[0m\n\u001b[1;32m    727\u001b[0m                 \u001b[0mobj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    728\u001b[0m                 \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_torch_load_uninitialized\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 729\u001b[0;31m                 \u001b[0mdeserialized_objects\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mroot_key\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrestore_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    730\u001b[0m             \u001b[0mstorage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeserialized_objects\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mroot_key\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    731\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mview_metadata\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mdefault_restore_location\u001b[0;34m(storage, location)\u001b[0m\n\u001b[1;32m    176\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdefault_restore_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    177\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_package_registry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 178\u001b[0;31m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstorage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    179\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    180\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36m_cuda_deserialize\u001b[0;34m(obj, location)\u001b[0m\n\u001b[1;32m    156\u001b[0m             \u001b[0mstorage_type\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    157\u001b[0m             \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 158\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mstorage_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    159\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    160\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/torch/cuda/__init__.py\u001b[0m in \u001b[0;36m_lazy_new\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m    431\u001b[0m     \u001b[0;31m# We may need to call lazy init again if we are a forked child\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    432\u001b[0m     \u001b[0;31m# del _CudaBase.__new__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 433\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_CudaBase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__new__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    434\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    435\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 2.33 GiB (GPU 0; 1.96 GiB total capacity; 0 bytes already allocated; 1.12 GiB free; 0 bytes reserved in total by PyTorch)"
     ]
    }
   ],
   "source": [
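    "# load the model -- this fails: torch.load restores the saved tensors onto CUDA (out of memory),\n",
    "# and the loaded state dict has no .detach(); the next cell loads with map_location instead\n",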
    "loaded_model.load_state_dict(torch.load(\"mymodel\").detach().cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# load the saved weights, mapping every tensor to CPU storage\n",
    "file = torch.load('mymodel', map_location=lambda storage, loc: storage)\n",
    "loaded_model.load_state_dict(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-0.6251,  0.3285,  0.5691,  0.1887, -0.0316], grad_fn=<SliceBackward>)\n"
     ]
    }
   ],
   "source": [
    "print(loaded_model.embedding.weight[1][:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_venue = loaded_model('venue', batch=data.y_index_dict['venue']).detach().numpy()\n",
    "z_auth = loaded_model('author', batch=data.y_index_dict['author']).detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_venue = z_venue[0:100]\n",
    "z_auth = z_auth[0:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "import umap\n",
    "\n",
    "embedder = umap.UMAP().fit(data,y)\n",
    "\n",
    "z_venue_2d = umap.UMAP().fit_transform(z_venue)\n",
    "z_auth_2d = umap.UMAP().fit_transform(z_auth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAF1CAYAAADx1LGMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAykklEQVR4nO3de5SddX3v8feXkAvhGiAJISGEAAEkhAQGBDwrplAuWkBNQeOyLjURPKf2oGcdOV7OoYzLHqvFnkW11jYK0qMcCgZQWm1BWjFaQmGCcQhOICaEMAkhwy2A5Ep+549n78wzO3tmP/fr57UWi8zMnr1/e8+zv/v7fH/f3+8x5xwiIlI+B+Q9ABERiUYBXESkpBTARURKSgFcRKSkFMBFREpKAVxEpKQUwKW2zMyZ2UkJ3ddDZvbxYX42o/FYBza+/mcz+0gSjyv1pgAuqTCzsWZ2i5k9a2avm9mvzOxdvp8vMLO9ZvZG479+M7vLzM7Jc9xZcM69yzn393mPQ8pPAVzSciDwHPBO4HDgBuAuM5vhu81m59whwKHAecAa4BdmdlHGYxUpJQVwSYVz7nfOuW7n3Abn3F7n3D8BzwBnt7mtc871O+f+FPgO8NXh7tfMzjOzh83sVTP7tZkt8P3sITP7s8bP3zCzfzSzo8zsdjN7zcwea/kAAXi3ma03sxfN7CYzO8B3f4vNrM/MXjGz+83seN/PLjazNWa2zcz+GjDfz0aZ2dca97ke+IOW57Cv3GJmHzWzXzZu/4qZPdNypnKCmS1vnMU8aGbfNLPvj/zqS10ogEsmzGwyMAt4ssNN7wHOMrOD29zHVODHwJ8BRwKfAe42s4m+my0CPgxMBU4EVgDfbdy+D7ix5W7fB3QBZwHvARY3Huu9wBeAhcBE4BfAHY2fHQ3cDfwv4GhgHfAO331eA1wOzGvc91UdnvPbgaca9/UXwC1m1vxA+H/Ao8BRQHfjuYkACuCSATMbDdwO/L1zbk2Hm2/Gy2aPaPOzPwJ+4pz7SSOr/ynQA7zbd5vvOufWOee2Af8MrHPOPeic2wP8AC+o+n3VOfeyc24jcDPwwcb3PwH8uXOur/G7XwbmNrLwdwO/cc4tc87tbvzeFt99vh+42Tn3nHPuZeDPOzznZ51z33bOvQX8PTAFmGxm04FzgD91zu1yzv0SuK/DfUmNKIBLqholie8Bu4A/CfArUwEHvNrmZ8cDVzfKJ6+a2avAf8ILeE0v+P69vc3Xh7Tc53O+fz8LHOt7rL/yPc7LeB8sUxu32fd7ztsRzn8/x7L//Y5kX/B3zr3Z+Ochjft52fe91vFKzR2Y9wCkuhplgFuAycC7G9lqJ+8DHnfO/a7Nz54DvuecuybBYR7HYFlnOt4ZQPOx/rdz7vbWXzCzkxu/1/za/F8Dz7d8PT3i2J4HjjSz8b4gftxIvyD1ogxc0vQt4DTgCufc9uFuZJ6pZnYj8HG82nM73weuMLNLGxOF4xrtiNNijPF6M5tgZscBnwLubHz/b4HPm9npjTEebmZXN372Y+B0M1vY6O2+DjjGd593AdeZ2TQzmwB8LsrAnHPP4pWIus1sjJmdD1wR5b6kmhTAJRWNWvEngLnAFl+/94d8NzvWzN4A3gAeA84AFjjnHmh3n8655/AmGr8ADOBlydcT7zj+EbASWIUXmG9pPNa9eN0w/2BmrwGrgXc1fvYicDXwFeAl4GTg3333+W3gfuDXwON4E7NRfQg4v/E4f4b3AbMzxv1JhZgu6CBSHmZ2J7DGOdfaTSM1pAxcpMDM7BwzO9HMDjCzy/DOQH6Y87CkIDSJKVJsx+CVYI4C+oH/4pz7Vb5DkqJQCUVEpKRUQhERKSkFcBGRksq0Bn700Ue7GTNmZPmQIiKlt3LlyhedcxNbv59pAJ8xYwY9PT1ZPqSISOmZWdvtGFRCEREpKQVwEZGSUgAXESkpLeQRkdzs3r2b/v5+duzYkfdQCmHcuHFM
mzaN0aNHB7q9AriI5Ka/v59DDz2UGTNmMHgRonpyzvHSSy/R39/PCSecEOh3VEIRkdzs2LGDo446qvbBG8DMOOqoo0KdjSiAi0iuFLwHhX0tFMBFRAL64Q9/yG9+85t9Xy9YsCDXtS0K4CIiAbUG8Dj27NkT+z4UwEumtxe6u2HxYu//vb15j0gkQym8Ad773vdy9tlnc/rpp7N06VIADjlk8NrXy5Yt46Mf/SgPP/ww9913H9dffz1z585l3bp1APzgBz/g3HPPZdasWfziF78AvNr+xz72Mc444wzmzZvHz372MwBuu+02rr76aq644gouueSS2GNXF0qJ9PbC174GEybAtGnwyive15/5DMyZk/foRFKW0hvg1ltv5cgjj2T79u2cc845/OEf/mHb211wwQVceeWVXH755Vx11VX7vr9nzx4effRRfvKTn/DFL36RBx98kG9+85sAPPHEE6xZs4ZLLrmEp59+GoAVK1bQ29vLkUceGXnMTcrAS+See7xjd8IEOOCAwX/fE+eKiyJlkdIb4Otf/zpnnnkm5513Hs899xxr164N9fsLFy4E4Oyzz2bDhg0A/PKXv+TDH/4wAKeeeirHH3/8vgB+8cUXJxK8QRl4qWzc6CUefocf7n1fpPJSeAM89NBDPPjgg6xYsYLx48ezYMECduzYMaQbpFNb39ixYwEYNWrUvrr2SBfKOfjggyOPt5Uy8BKZPh22bRv6vW3bvO+LVF4Kb4Bt27YxYcIExo8fz5o1a3jkkUcAmDx5Mn19fezdu5d777133+0PPfRQXn/99Y73O3/+fG6//XYAnn76aTZu3Mgpp5wSeZzDUQAvkYULvbLfK6/A3r2D/26cwYlUWwpvgMsuu4w9e/YwZ84cbrjhBs477zwAvvKVr3D55Zdz4YUXMmXKlH23X7RoETfddBPz5s3bN4nZzh//8R/z1ltvccYZZ/CBD3yA2267bV+mnqRMr4nZ1dXltB94PL29Xslv40Yv8Vi4UBOYUl59fX2cdtppwX+hBm+Adq+Jma10znW13lY18JKZM6dyx6tIcHoDDKESiohISSmAi4iUlEooOahBGU9EMqAMPGPNxWSvvDJ0MZmWxItIWMrAM+ZfTAaD/7/nnvBZeDOTX7UKXn0VjjgC5s4dmtEr2xepLmXgGdu40Vs85hdlMVkzk3/6aVi/3gvg69fD2rWDGb2yfZFqUwDPWFKLyZqZ/ObNcNBBXvZ90EGwadPg9hDaO0Wk2hTAM5bUYrJmJr9tG4wb531v3Djv62ZGn1S2L1IUSe8m+9nPfpa/+Zu/2fd1d3c3f/mXf8lNN93EOeecw5w5c7jxxhsB2LBhA6eddhrXXHMNp59+Opdccgnbt28Hhl7Y4cUXX2TGjBkAvPXWW1x//fX77uvv/u7v4g24hQJ4xubM8Xa/nDAB+vu9/0fZDbOZyR9+ODT32tmxYzCoT5+uvVOkWtIoCS5atIg777xz39d33XUXEydOZO3atTz66KOsWrWKlStXsnz5cgDWrl3LJz/5SZ588kmOOOII7r777hHv/5ZbbuHwww/nscce47HHHuPb3/42zzzzTPQBt9AkZg6SWEy2cKF38B57LKxeDTt3ehn9SSd5B/aSJd7tvvY17//NwO7/mUiZJNkA0DRv3jy2bt3K5s2bGRgYYMKECfT29vLAAw8wb948AN544w3Wrl3L9OnTOeGEE5g7dy4wdPvY4TzwwAP09vaybNkywNs8a+3atYGvOt+JAnhJNTP5e+6BN98c7EI5+eShnSbN2zS7UJYsUReKlFNa2ylfddVVLFu2jC1btrBo0SI2bNjA5z//eT7xiU8Mud2GDRuGbEg1atSofSWUAw88kL179wJDt591zvGNb3yDSy+9NN4gh6EAXmJBMnltHSFVMX26dwbZzLwhmZLgokWLuOaaa3jxxRf5+c9/zhNPPMENN9zAhz70IQ455BA2bdrE6NGjR7yPGTNmsHLlSs4999x92TbA
pZdeyre+9S0uvPBCRo8ezdNPP83UqVMT2xNcATwDRenFLso4RKJolg0h2ZLg6aefzuuvv87UqVOZMmUKU6ZMoa+vj/PPPx/wro/5/e9/n1GjRg17H5/5zGd4//vfz/e+9z0uvPDCfd//+Mc/zoYNGzjrrLNwzjFx4kR++MMfxhuwj7aTTZn/Mn7+gy7r61gWZRwifmG3k61DEqLtZAskjYmXMo9DJA6VBIdSG2HKitKLXZRxiEhyFMBTVpRe7KKMQ0SSowCesigrL5NebRZ1HCJZyHIerujCvhYdA7iZ3WpmW81ste97N5nZGjPrNbN7zeyI8EOth7ArL9PagCqpFaAiSRo3bhwvvfSSgjhe8H7ppZcY19wbI4COXShmNh94A/i/zrnZje9dAvybc26PmX218eCf7fRgdexCCau7e/9e1+bX3d15jUokHbt376a/v3/I4pc6GzduHNOmTduv7zxyF4pzbrmZzWj53gO+Lx8Broo2XGmV1mozkSIaPXp0YsvK6yiJNsLFwJ0db1UxafWjNleb7dwJa9Z4E41jxkBjWwYRkX1iTWKa2f8E9gC3j3Cba82sx8x6BgYG4jxcYaR5oYSFC2HdOvj5z709TkaPhtde8/b51oUYRMQvcgA3s48AlwMfciMU0p1zS51zXc65rokTJ0Z9uEJJ80IJc+bAccfBYYfB7t0wfjwsWAAzZ+pCDCIyVKQSipldBnwWeKdz7s1kh1R8adepd+6ESy/1Phya9u5VHVxEhgrSRngHsAI4xcz6zWwJ8NfAocBPzWyVmf1tyuMslLQXxWjRjYgE0TGAO+c+6Jyb4pwb7Zyb5py7xTl3knPuOOfc3MZ//zmLwRZF2otitOhGRILQSswI0l4Uo0U3IhKEdiOMKOld0dq1JWrhjoiMRBl4AaTZligi1aUMvADi7tVdh03uRWR/CuDDyDIoxmlL9F9px5+9q2YuUn0qobSRdUkjTttgmouKRKTYFMDbyDooxmkb1JV2ROpLAbyNrINinLZBLfoRqS/VwNto7gjo35M77aAYtS1x4UKvvANDrza/ZEmy4xOR4lEG3kaZVkJq0Y9IfXW8Ik+SynRFHrXmiUhRRL4iT10lvdJSRCRpCuApUQYvImlTDTwFWhovIllQAE+BFteISBZUQklBEa4srxKOSPUpA09B3otrVMIRqQcF8BTk3UeuEo5IPdSmhJJlSaG5uMb/eEuWZFfCKEIJR0TSV4sAnseWq3n2keexFYCIZK8WJZS6lRTyLuGISDZqkYEnUVIoU1dH3iUcEclGLQJ43JJCGa96o60ARKqvFiWUuCWFupVgRKQcahHA4265qqveiEgR1aKEAvFKCll2dZSp1i4i+apNAI8jravetAbr2bPhvvvKVWsXkfzUooQSVxpXvWm33P1LX4I9e1RrF5FglIEHlHRXh39iFLz/794NmzbBrFmDt1OtXUSGU6kAHrd+nGX9uV1v+sSJMDAw9HtaQSkiw6lMCSXuDnxZ7+DXbsfCadNg9GitoBSRYCoTwOP2amfd692uN33UKLjhBl1hXkSCqUwJJe5y+ax38BtpuftVV6XzmCJSLR0DuJndClwObHXOzW5872qgGzgNONc515PmIIOI26s93O+PGQPd3enUxbXcXUTiCFJCuQ24rOV7q4GFwPKkBxRV3OXy7X5/3TqvK0RXthGRIuoYwJ1zy4GXW77X55x7KrVRRRC3V7vd7x93HMycqb5sESmm1GvgZnYtcC3A9JT74eKWJFp/f/Fir7XPT33ZIlIUqQdw59xSYClAV1eXS/vx/OL2devKNiJSZKVuI+zt9SYYFy/2/u+vTSfR160r24hIkZU2gHcK0En0daexB4qISFKCtBHeASwAjjazfuBGvEnNbwATgR+b2Srn3KVpDrRVu71Emt+fMye5vm61+olIUXUM4M65Dw7zo3sTHksonQK06tciUnWlLaG020vEH6BVvxaRqittAO8UoFW/HnmSV0TKz5zL
rrOvq6vL9fQkt+re3yY4ZgyYwc6duhQZDE7yTpgw9CpCdfsQE6kCM1vpnOtq/X6pN7NqTjD6g9XEidW+FFnQ3vZOk7wiUn6lLaH4Zb0VbF7C9LZv3Ohl3n5aRSpSLZUI4HUJVmE+qDpN8opI+VUigNclWIX5oFIXjkj1lboG3rRwoVdKgKETdkuW5DuuIMLs1xKkt91/fwcf7E3q9vcPvWCEiFRDJTLwsrYMht2vpVNW3Xp/Y8bAm2/Cpz/ttREW/fUQkXAqkYFDOZe8h+0UGekybFHuT0TKrTIBPCtxt6j1i7Jfy0gfVFlf11NE8lWJEkpWktii1i/pyde6TOaKiEcBPISk+82T7hRR54lIvSiAh5B0v3nSk6/t7u/KK70PGO2HIlI9qoGHkMYWtUlPvvrvz7/FgL/kU4YOHRHpTBl4CGUrUdRliwGRulIAD6Fs/eZ12WJApK5UQgmpTP3muiqRSLUpgGcoyR7yIMq8xYCIdKYSSkaS7iEPomwlHxEJRxl4RvJa5l6mko+IhKMMPCOaUBSRpCmAZ0TL3EUkaSqhZCTqhGIWE59ZT66KSDKUgUfU2+stTQ+6RD3KhGIWE595TK6KSDIqnYGnlVlGXaIedkIxi4lP7SEuUl6VzcDTzCyzWqKexcSnJldFyquyGXjSmaU/m3/8cXj724f+PI2gl8VKSq3WFCmvymbgw2WWq1aFq13D/tn82LGwfDls2TJ4mzSCXhabZ5Vtgy4RGVTZAN6ube+3v4VnnglfVmktmZx1lvf9X/0q3aCXxUpKrdYUKa/KllDate09+STMnh2+rNJ6rcnJk2H+fHj0US/otV5cOElZrKTUak2RcqpsAG93BfcTToATTxx6uyC163Z14nHj4D3v8cowIn7qq5esdAzgZnYrcDmw1Tk3u/G9I4E7gRnABuD9zrlX0htmNK2ZZXd3tAk77eoXTR0Dma6CJFkKUgO/Dbis5XufA/7VOXcy8K+Nrwsv6oRdnevEYRcs+X+vjguE/PMlAwPw619DTw9cd131n7tkz5xznW9kNgP4J18G/hSwwDn3vJlNAR5yzp3S6X66urpcT09PzCHHU8esMCp/Nuk/8wjy4dXubKf5dZXLTosXex9YAwPw8MNeqW3sWHjxRejqqs8HvyTLzFY657pavx+1Bj7ZOfc8QCOIT4o1ugxpwi64OL30rRO/UM0FQq0JwZgx3gddX58XvA86CLZvh0mTBhd76fiTpKTeRmhm15pZj5n1DAwMpP1wkqA4qzTrsPtiuzLRpk2wbh1s3epl3tu3w44dcOqp1fwAk3xFzcBfMLMpvhLK1uFu6JxbCiwFr4QS8fFSpbJKe3FWac6eDV/6EuzeDRMnegFu1KiRJ36L/HdoN7Z2ZygzZ8KuXfDGG14QnzQJ5s2DY47xXssqfYBJ/qJm4PcBH2n8+yPAj5IZTvbqOtkWRNRJ395euO8+OP10L3gPDMDq1XDllcMH5CL/HYYb26pV7c9Qdu6Er3/dq3mfeaYXxLXCVdLQMYCb2R3ACuAUM+s3syXAV4CLzWwtcHHj61LKamOqMorafdN8TWfNgt/7PXj/+2HBAi+Id/qdIv4dhhvbq68OXyaqc+eSZKdjCcU598FhfnRRwmPJRV0m26KKMukb5TUt8t9huLEdcYSXVTe/bl0foAlzSVtlV2IGVabd+IpcI/aL8poW+e8w3Njmzh2shTf/JmltqSDSTmU3swqqLLvxFblG3CrKa1rkv0ORxyb1VvsAXpZaZZFrxK2ivKZF/jsMNzYoz4eqVFNpSyhJlhPSqlUmOcYi14jbifKaFrlm3G5s3d26HJ3kq5QZeBnKCUmPsQ4LY8pGl6OTvJUygJehnJD0GFWHLR59qEreShnAy5D5JD3GJGvEUXcYlKH0oSp5K2UNvMgtZ01pjDGJGrH2q05Ou4uGqI1QslTKAF6GCywUdYxRdxgsSw961oo88SrVV8oSSpFbzppax7hrF4wfDzffnG/ZIkpp
pwyTxkWlcpWkqZQZOGST+cTNOptj9JctmhsbJVm2CDPOKKWdOPuC15nKVZK2UmbgWUgy60yzaybsOKNMvJVh0rhoenu9y6j19HiXVRsYKGa3lJRb4QN4XqegSQbdNANg2HFGKT+pXS6c5ofq1q1w9NHeRR0efhheeEEffJKsQpdQkjgFjVoGSXLlY5pdM1HGGbb8VNQJ2aJqfqhOmuQF74MO8r7f1+ddck0ffJKUQmfgcbPgOGWQJLPONPuFs8iOyzBpnLYwZ4LNM65TT/Uup7Z9u3d5ta1b1ScuySp0Bh4lu/Rn3OvXw9Sp0Sbfksw60+wXzio7rnO7XNgzweYZ1zHHwPnnw5o1g5dXq9sHn6Sr0AE8bOmh9Y32yCPw8stw2GEwebJ3m6BlkKSDbloBMIlxqsd7ZGG7cPwfqpMmedn3K68oeEvyCh3Aw2aXrW+0SZO8y1719Q0G8CpOvsX5cFCrW2dhzwS1QlOyUugAHvaN0PpGO/VUb/Z/61av9hymvFCXwKYe786iTELXueQk2Sl0AIdwb4TWN9oxx8Ds2bB5szf5FiYTKmJgC1LqCFsOKds+43lQF44UVaG7UMJq1+1x4IHw9a/Drbd63QNBg2/RFq8E6aiJ0nXT7GLZsgUeegh+9CO4/36v3U086sKRoip8Bh5GkrXHdqfN69bBpk1eK1nWk31BzgiinDUsXAhf+IL33A49FEaPhtde855nb285glQWk7AqiUgRVSoDB+9N1t0dPuNu1ZrNr10LK1bAscfms6FTkDOCKGcNc+bAccd5nTq7d3sbbi1YADNnlmPJtzbakjqrVAaepNZsftMmOO88mDXL+3m77DbNTDDIRFrUFZ87d8Kll3qLpZr27i1OHXyk17WIcxUiWalcBp4kfzY/cyacdNLQn/uz27QzwSCrOaOu+CzyXiedXteizVWIZEkBPKBOQS7t63QGmUiLOtlW5EuDdXpdi/zhI5I2lVAC6tRKlkU73nATaUmUbg4+GJYvB+e8UlFRuiw6va5q8ZM6UwYeUKfsNq9MMG7ppvn7Y8bAFVfAO98Jb76Z7pjD6PS6qsVP6kwZeAgjtZKllQl2yq7jTuLlNQkY9KwhyOuqFj+pK2XgCUkjEwySXcedxMtjEjDMWYMybJHhKQNPUGsm2NxDOmptOkh2HPdiEWlebGI4YbN+ZdgJ0JaTlaQMPCVJtBUGyY7jdpDk0YGi1r+MabVTZSmApySJtsIgE6NxSwx5lCjU+pextHtcJTexSihm9ingGsCAbzvnbk5iUFWQRFth0InRuCWGrEsUav3LWBIHo0owhRQ5Azez2XjB+1zgTOByMzs5qYGVXRJZZlUn8Kr6vHI33IU74x6MKsEUVpwM/DTgEefcmwBm9nPgfcBfJDGwrCWdYCSVZfr3/Ni4cfCst+zBThOTCRvpCiRxD8Y4vabK3FMVpwa+GphvZkeZ2Xjg3cBxrTcys2vNrMfMegYGBmI8XHrSSDCSyjKV/NREmMvetzNSnTvuwRh11lkHb+oiZ+DOuT4z+yrwU+AN4NfAnja3WwosBejq6nJRHy9NaS1mSSLL1G57NZDE9fs61bnjHIxRe0118KYuVheKc+4W59xZzrn5wMvA2mSGla0it7UVeWySkKxalqKK2muqgzd1sQK4mU1q/H86sBC4I4lBZa3IbW1FHhvEP/MXkgl0aTb0Ry3BFP3grYC4feB3m9lvgH8EPumceyWBMWWuyNupFnlsKnEmpAwtS1EudVXkg7cizLnsytJdXV2up6cns8cLo8iT5XHHltZz6+7evzTa/Lq7O/7914a/Bu7vEqlCb2WR31glYmYrnXNd+31fAbza0owNixd7mXfrpdj6+71ELc6Ya/eer+WTlqCGC+DazKri0mwESGMjrCQaMkpJjfESgfZCqbg0GwHSKHFq245haLZY2lAGXnFpbhfbnDfzn/kvWRIvkczi0nS5iFMiyfu0pN3YQSWfAlAAr7i0N45K+sw/j/3J
Uxc3AOe5IKbd2L/wBTCDmTNrVucqHpVQKq5sG0dVsvMsbl0o7QUxI5Vn2o19YAC2blWdqwCUgZdUmDPyMs2PpVGWyV3culCapyWdzg7ajX3nzv3vpxJ1rvJRAC+IMAE575Jo2sr0gRNI3ACcZh2sU3mm3djHjt3/fvKoc6n1UiWUIgi7olGdGiUTty6UZh2sU3mm3dgnToRJk/Ktc2kZMKAMvBDCzlFVtlOjqpKoC6V1WtLp7KDd2L/8Ze9neda5tNMhoABeCGEDciU7NaquqHWhIOWZ4cae5/NRFgOohFIIYfcyqmSnhuSjbG1KTdrpEFAGXghh56iCnJFrfkcCS+rsIOxBF+cg7fSmqckbQJtZFUSSx1uVN7eTggp70CVxkA73pqngG0CbWRVckiVSze9I5sIedEkcpMO9aWr0BlANvIJ0JSvJXNiDLs2DtEZvAGXgFaQulZTUpK4aSdiDLs2DtEZvAGXgFaQulRRo4cjIwh50aR6kNXoDaBKzopQsJqy7W9eP6yTLLpS4YynZG0SXVBOJI63rx4VRsqBTWCXsUhkugKuEIhJE68KRF16A+++Hxx8Pd4WcqFfWUQknORXaTEgBXCQIf131+efhoYfgtdfg7W8PHkzjBOEKBZ3cVahLRQFcJAj/kvNHH4XDDoN3vhOmTPG+99ZbcN11I2fWcYJwhYJO7iq0DF9thCJBNReONDdSatbDX3gBnngC9uyB+fOH36A9zgZMcVvjVD8flPZ1BjOkDFwkrNYMrq/PC+aTJo2cWcfJ/OK0xql+PlRZN/BqQxm4SFitGdzWrXDggXDqqYO3aZdZx8n8Ou1gNlKGHXZpuf++xozxLmC8c2e1Mveibu8bktoIRaLwB7n16+HYY2HWrMGfD9cjnkYpo1NbXJgWSP997dgBy5d7358/H8aNK3y7XVVpMyuRJPkzOH+JolNmnUbmF+W6lsOVbvz39dBD3mQtwFNPwYIFQ+9XcqcauEhceddUo1zXcrj6uf++tm3zsu5x4wZr9+p8KRRl4DIsNS6EkGdNNcp1LYe7hqX/vg4/HLZv977vD+olbLerKgVwactfCvU3Lqj8WUBxrms50n2dcspgDXzu3MHMvcjtdjXLOlRCkba08K9Ekizh+O9r925vsdKCBd6/i95uV8N2yVgZuJn9N+DjgAOeAD7mnNuRxMAkX7rod8kkWcIpa4tdja7E0xQ5AzezqcB1QJdzbjYwCliU1MAkXxVabSx1UcPtBuLWwA8EDjKz3cB4YHP8IUkRVGi1sQRRhdpxja7E0xQ5A3fObQK+BmwEnge2OeceSGpgkq+8O+MkQ1nVjqNupRv0fmp0JZ6myCsxzWwCcDfwAeBV4AfAMufc91tudy1wLcD06dPPfvbZZ+OMV6R+0s6Ou7vTv9pQUhdR6HQ/VTiTaCONlZi/DzzjnBtoPMA9wAXAkADunFsKLAVvKX2MxxOpnyz6ObOYsU5qgrHT/ZR1AjaiOAF8I3CemY0HtgMXAdroRCSsJDeiiiKL2nFSHxJqjxoiTg38P4BlwON4LYQH0Mi0RSSgTvXnLDorsqgdJ9XWpPaoIWJ1oTjnbgRuTGgskpKKlgWrIcmNqKIKs9Q+qqTammbPhi99yVtYNHEiTJ3qbeXbej81Oei1nWzFlfAC3OUUNWB02uq1Sn/AuEG1+Vq89Zb3+gwMwOjRcMMNcNVV+9+uCq9Zg7aTrakaLk7LXpyJxiQ3oiq6uBOM/oP55JO9773yCqxePTSA1+igVwCvOM35ZCBOwEhyI6o8ZVGyCHowB71dBcos2syq4jTnk4E4E41VWDGV1UKgoAdzkNtVZOMrZeAV1HpJw02bYOZMLYlPTdyJxiAZdpGzxaxKFkEnQoPcriJlFgXwkgj6/m0tx27bBs7Brl1eglfmEmomogTKtDeOibuYJ6ngP9z9ZFWnCzofEOR2FaktqgulBMJMqmexKrqy4nQvpJkhx/mjZrGE/Z57ynfQleyNoi6UEgtz
tleRxCIfcU6r05xojPNH7fScgn7wjHQ/Zdy6soxjbkOTmCUQZo5Mk5YxFHU/6Th/1JGeU5iJvJHup4wTsWUccxvKwEsgzBxZFuXYos6lxVbU/aTj/FFHek5hzjiC9KuX7UAIOuYCH/TKwEsgzFYVaSYWFem8Gl5R95OO80cd6TmFOeMo6muTtoIf9JrELIkiJAElm/eJJusXOovH8z/G2LGDbUnr13t7iTRXNcLIf9AiHIRZK8hBr0nMkivCGWotJkizfKGz2OsbBp+T//EmTYIdO2DFCu82J57YuTRThIMwawU/6BXAJbCilohLK+vFJK2PN2uW9/9Nm7zMXIsE9lfwg14BXAKrSOdVcYTJ7pIoX7R7vJNOgnHjvJ0PZX8FP+gVwCskzGrNKLGgShvjhZJW7TdodpdUqSXJbLIu9fCCH/SaxKyIoAvuKrhVcrrSfMGC3ndSE2lZXVg4CXX5gAhouElMtRFWhL+8ecABg/++555ot5OGNF+woO2BSS0wSqrHNO2DqOCte0WiEkpFJL1VsjSk/YIF6exIsvSRRCdJ2q9JRXYKzIIy8IpIcqtk8SnCC1a0RTRpvyZF3dKggBTAKyLoe7xosaDwivCCFW3fjrRfkyJ8aJaEJjErJO0ulLTup/Bq80RDSPM1SXKytSJ/t+EmMRXAJRJ1s0iqkrqCfUUOUC2ll0RpnikHFcooO0ryCvZQ2QNUAVwiUTdLG1mVFdLcN6Uq0j5AC/JhqklMiUTzTC3S7l1WA384aR6gBepTVwCXSIrQnFEoaQfYNFrrenu9lZyLF3v/r9JCmTQP0AJ9mCqASyRF62zbJ6+glHbvctIZZYGyyFSkeYAWqE9dNXCJrHDbQ+dZJ05729Gkd8ULM8lXkHpvaGkdoAXaYlYZuFRHnqe2adeUks4og2aRVc/UoyhQ/VAZuFRHnq0x/m1HV62CV1+FI44Y/PBIIhNMMqMMmkXWpB0vlAJtMasALtWR96lt8w28fj0cf7z34VHUdr+gJZmkPhSzLMNk8VgFqR+qhCLVUYRT2wJ1KIwoaEkmicnTLMswNSv5RM7AzewU4E7ft2YCf+qcuznuoEQiKcKpbZlWOAXJIpOYPM2yDFOzkk/kAO6cewqYC2Bmo4BNwL3JDEuyksbZZq5NC3mf2uZdxklaEh+KQT7UkjpoyvQBmoCkSigXAeucc88mdH+SgTTONmt2Bru/IpRxkjZnjtdTf+ut3v/DBtZOZZgkD5qaLRFOKoAvAu5o9wMzu9bMesysZ2BgIKGHkySkUa4tSwk4NYVd4ZSjTh9qSR40RfwATXFxWeztZM1sDLAZON0598JIt9V2ssWyeLGX8Bzg+xjfu9eLO7feWpz7rLSyLpIJa6TnmfRBU6TXNKFtbdPcTvZdwOOdgrcUT6dybZT3QdVKwKmq0w6DI81NJH3Q5D0P4pfypGoSJZQPMkz5RIptpLPNqGXJwpzBlmGjpiyu7l701wAKdNCkIOV9U2IFcDMbD1wM1KXCWSkjlWujxpZClIDLMpOa5pu7LK8BFOSgSUnKk6qxSijOuTeBoxIZieRiuLPNON1YuZ/BlqUXOM16U1leg9Y63ac/XazxxZX0JmQttJRegP3fR2PHesdaKWvZefQCR5kwSPPNXYZ+6DrMAaS8uEwBXNq+j557Dsxg5sxUEod0ZT2TGjUQpfnmLsNsct5nCVl1q6R4SqoALm3fRyeeCDt3ev/OecO18JLMbIO8yYMEouHuJ603d9avQRR5niVUJPtPaiGPlNhwc2m7dsVbgJebpCbFgk4EdpqMzGNCMevXIIo8V01WZMWZMnApxdl2aElktkFP8Tu9gHmVCrJ8DaJIeYJvRGWYIwhAGbhUug03lqBtfp1ewAJdQzG0NMeeZ/tgRfZMUQYusefSirRyOVFBT006vYBlPsVJe+x59ZyWYY4ggNh7oYShvVCqJ6GtHoopqSdX5BepU/Ap8tjjSiLwZvT6DLcXigJ4zYU9hltvv2WL
1zPuT9CaCVt3d+rDjy7oE08qu8oqSwvzOEGDT2VPsVpEeZ7d3fufoaTwBlAAl/2ETR7a3f7HP4aLLoIpUwZvV/jdB6uaVYZ9XhkFn1KIekxktP1mmrsRSon4k4z16+HYY4M3GLRrSDjqKO8i7P4AXvjybt4LSNIS9nll0YkRJauNk/FH/d2ox0TO8xvqQqmR1pberVth9WqvDNI00vu3XUPC3Lnw0ksl62Apc1fISMI+r7Q7MaL0kMfpO4/zu1GPiZxbuBTAa6R17cKkSd7/16wZvM1I79927/dx4+Dii0u2kVxFWsj2E/Z5pR18oiyWibPAJs7vRj0mct5JUSWUGmk9Yz7tNPj3f/cy8b17O3dSDdd5VfiA3SrPBSRpCvu8Ut5oKVKJJk5ZJ87vxjkmctx+Uxl4jbQmGZMnwxlneJl4kOShMts2V+aJtIjyvPwXLF640AvmSV0AIkpWG+fsKM7vlvSYUBdKjVS1+UISkMbBEeU+44yjwge42ggFqE9Lr4TQ2wvXXefV0iZNglNPhWOOSaalsCxdKAWnAC4i+2tmrT09cPTR3h7CO3bA+ecP1tYK29BfH+oDl/qpaDaWqGbnxqRJsH07HHSQ9/01a7wltlE7c/TaZ0KTmFJNZbqob56a/c+nneZl3tu3e4F769boLYV67TOjAC7VVJEN+1PX7NyYPBkuuMDLwF980cvIo07+6bXPjEooUk2tPcFbtkBfH2ze7H2d5il9mcoH/v7niRNhzJj4nRsVuVhCGSgDl2ry9wRv2QIrVnhfH3tsuqf0ZSsfpNH/XNWVrgWkDFyqyZ9Z9vWBGTgHb3tbuptXlXGjrKRXElZ1pWsBKQOXavJnlps3e4Hkggu8Wi+kd0qf9kZZvb1eX3ZSqyXTUNJVjWWkDFyqy59ZZrXlZ5rbi/pXGvrLM0UMjjnuD1InCuBSfVme0kd5rKCTnkHLM2WaRJVYVEKR6svylD7sY4WZ9AxSninbJKrEogxc6iHLU/owjxVm0jNIeaaMk6gSmQK4SJ7C9EwHKc9UoQc76xJQiUtOKqGIQH7dHWF6poOUZ8reg511CajkJScFcJE838RhL2vmvwBDd/f+mWLO12iMLetl+CVf9h8rgJvZEWa2zMzWmFmfmZ2f1MBEMpPnmzjpCday92BnfcHpkl/gOm4N/K+Af3HOXWVmY4DxCYxJJFtx6sZJ1E+TnmCNcn951oH9j71+vbcr4qxZgz9PswSUZt9+BiJn4GZ2GDAfuAXAObfLOfdqQuMSyU7UunEepZc0avVRn0cSY2l97KlT4ZFH4OmnsykBtSs5rV/v7Z9T5NWuDXFKKDOBAeC7ZvYrM/uOmR2c0LhEshO1bpx16SWtD4wozyOpsbQ+9skne1cD2rw5mxJQa8lp1y5vz5yxY0sxqRkngB8InAV8yzk3D/gd8LnWG5nZtWbWY2Y9AwMDMR5OJCVR68ZZ10/T+sCI8jySGku7xz7xRJg5c/iJ2qT5J4YnT/YevySTmnFq4P1Av3PuPxpfL6NNAHfOLQWWgndNzBiPJ5KeKHXjrOunafV4R3keSY2laDXokvXRR87AnXNbgOfM7JTGty4CfpPIqETKIOuWvSC1+ih16SjPI6l+86K1PZasjz5uH/h/BW43s15gLvDl2CMSKYusW/Y6BbuodekozyOpwFu0tseifaB0YM5lV9Xo6upyPT09mT2eSOWM1O7X3b1/OaL5dXd3tmMpswI+LzNb6Zzrav2+9kIRKZORavVZ12+ruud3iZ6XltKLVEXJ6rcSnwK4SFWUrH4r8SmAi1RF0SYEJXWqgYtUSYnqtxKfAriIRJNXt0YBu0TyohKKSJbyunBE0vLaQ73kF2BImgK4SFaqFHzy2kO95BdgSJoCuEhWqhR88roQQskvwJA0BXCRrFQp+OTVc65e9yEUwEWyUqXgk1fPuXrdh1AAF8lKlYJPXj3n6nUfQptZiWSpCi1wVXgOJaPNrESKIO2FNmkH12YnzYQJ
QztpapwF50klFJGqyKJNsUqdNBWgAC5SFVkE1yp10lSAArhIVWQRXKvUSVMBCuAiVZFFcK1SJ00FKICLVEUWwVVtfIWiNkKRKlGLXyWpjVCkDrQfeK2ohCIiUlIK4CIiJaUALiJSUgrgIiIlpQAuIlJSCuAiIiWlAC4iUlIK4CIiJaUALiJSUgrgIiIlleleKGY2ADyb0t0fDbyY0n3HpbFFo7FFV+TxaWzhHe+cm9j6zUwDeJrMrKfdZi9FoLFFo7FFV+TxaWzJUQlFRKSkFMBFREqqSgF8ad4DGIHGFo3GFl2Rx6exJaQyNXARkbqpUgYuIlIrlQngZna1mT1pZnvNrBCzyGZ2mZk9ZWa/NbPP5T0ePzO71cy2mtnqvMfSysyOM7OfmVlf42/6qbzH1GRm48zsUTP7dWNsX8x7TK3MbJSZ/crM/invsfiZ2QYze8LMVplZ4a6taGZHmNkyM1vTOPbOz3tMnVQmgAOrgYXA8rwHAt6bCPgm8C7gbcAHzext+Y5qiNuAy/IexDD2AP/dOXcacB7wyQK9djuBC51zZwJzgcvM7Lx8h7SfTwF9eQ9iGL/nnJtb0Fa9vwL+xTl3KnAmxX0N96lMAHfO9Tnnnsp7HD7nAr91zq13zu0C/gF4T85j2sc5txx4Oe9xtOOce94593jj36/jvZGm5jsqj/O80fhydOO/wkwkmdk04A+A7+Q9ljIxs8OA+cAtAM65Xc65V3MdVACVCeAFNBV4zvd1PwUJQmViZjOAecB/5DyUfRolilXAVuCnzrnCjA24GfgfwN6cx9GOAx4ws5Vmdm3eg2kxExgAvtsoP33HzA7Oe1CdlCqAm9mDZra6zX+FyWx9rM33CpOplYGZHQLcDXzaOfda3uNpcs695ZybC0wDzjWz2TkPCQAzuxzY6pxbmfdYhvEO59xZeGXFT5rZ/LwH5HMgcBbwLefcPOB3QKHmrdo5MO8BhOGc+/28xxBCP3Cc7+tpwOacxlI6ZjYaL3jf7py7J+/xtOOce9XMHsKbSyjCZPA7gCvN7N3AOOAwM/u+c+6Pch4XAM65zY3/bzWze/HKjIWYs8J7v/b7zqaWUYIAXqoMvGQeA042sxPMbAywCLgv5zGVgpkZXi2yzzn3f/Iej5+ZTTSzIxr/Pgj4fWBNroNqcM593jk3zTk3A+94+7eiBG8zO9jMDm3+G7iEYnzoAeCc2wI8Z2anNL51EfCbHIcUSGUCuJm9z8z6gfOBH5vZ/XmOxzm3B/gT4H68Sbi7nHNP5jkmPzO7A1gBnGJm/Wa2JO8x+bwD+DBwYaPlbFUjqyyCKcDPzKwX70P6p865QrXrFdRk4Jdm9mvgUeDHzrl/yXlMrf4rcHvjbzsX+HK+w+lMKzFFREqqMhm4iEjdKICLiJSUAriISEkpgIuIlJQCuIhISSmAi4iUlAK4iEhJKYCLiJTU/wfqzngJPg3GugAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Scatter the 2D author and venue embeddings on shared axes so the two node types can be compared.\n",
    "# Uses the explicit Figure/Axes API rather than the implicit pyplot state machine.\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "ax.scatter(z_auth_2d[:, 0], z_auth_2d[:, 1], color=\"red\", alpha=0.5, label=\"author\")\n",
    "ax.scatter(z_venue_2d[:, 0], z_venue_2d[:, 1], color=\"blue\", alpha=0.5, label=\"venue\")\n",
    "ax.set(title=\"2D embedding\", xlabel=\"dimension 1\", ylabel=\"dimension 2\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
